/**
 * Java implementation of an AVL Tree of type T.
 * 
 * @author Jonathan E. Landrum <me@jonlandrum.com>
 * @since 2014-08-05
 */

package com.jonlandrum.selfbalancingtree.avltree;

import com.jonlandrum.selfbalancingtree.Node;
import com.jonlandrum.selfbalancingtree.SelfBalancingTree;

/**
 * AVL tree: a binary search tree that keeps the left and right subtree heights of
 * every node within one of each other by rotating after insertions and removals.
 *
 * <p>Height bookkeeping: every node stores a left height and a right height. A
 * missing child contributes {@code 0}; a present child contributes
 * {@code child.getHeight() + 1}. NOTE(review): this class assumes
 * {@code Node.getHeight()} and {@code Node.getBalance()} are derived from those two
 * stored values, with a positive balance meaning "right side taller". That matches
 * how the balancing code below uses them; confirm against {@code Node}.
 *
 * @param <T> element type; elements must be comparable to each other
 */
public class AVLTree<T extends Comparable<T>> extends SelfBalancingTree<T> {
    /*
     * Constructors
     */

    /** Creates an empty tree. */
    public AVLTree() {
        super();
    }

    /**
     * Creates a tree seeded with one element.
     *
     * @param d the initial element
     */
    public AVLTree(T d) {
        super(d);
    }

    /**
     * Creates a tree seeded with an existing node.
     *
     * @param n the initial node
     */
    public AVLTree(Node<T> n) {
        super(n);
    }

    /*
     * Operations
     */

    /**
     * Wraps {@code d} in a new node and inserts it.
     *
     * @param d element to insert
     * @return see {@link #addElement(Node)}
     */
    @Override
    public Node<T> addElement(T d) {
        return this.addElement(new Node<T>(d));
    }

    /**
     * Inserts {@code n} using the superclass's insertion. It then restores the AVL
     * invariant on the path from the inserted node up to and including the root.
     *
     * @param n node to insert
     * @return the node returned by the superclass insertion. If that node reports
     *         {@code isEmpty()}, it is returned without any rebalancing.
     */
    @Override
    public Node<T> addElement(Node<T> n) {
        Node<T> result = super.addElement(n);
        if (result.isEmpty()) {
            return result;
        }
        // Start at the inserted node itself. balance() refreshes heights bottom-up and
        // examines every ancestor, including the root.
        this.balance(result);
        return result;
    }

    /**
     * Wraps {@code d} in a new node and removes the matching element.
     *
     * @param d element to remove
     * @return see {@link #removeElement(Node)}
     */
    @Override
    public Node<T> removeElement(T d) {
        return this.removeElement(new Node<T>(d));
    }

    /**
     * Removes the element matching {@code n} and then rebalances the tree.
     *
     * <p>The superclass's removal strategy is not visible here. Before anything
     * changes, this method records every node that could be the lowest point touched
     * by the removal: the in-order successor's parent, the successor, the target,
     * and the target's parent. After removal it rebalances from the deepest of these
     * that is still linked into the tree. NOTE(review): this assumes a node with a
     * right subtree is replaced using its in-order successor (by relinking or by
     * copying data). That is the strategy the previous implementation assumed.
     *
     * @param n node whose element should be removed
     * @return the node returned by {@code findElement} when it is empty (nothing
     *         removed); otherwise the value returned by the superclass removal
     */
    @Override
    public Node<T> removeElement(Node<T> n) {
        Node<T> target = super.findElement(n);
        if (target.isEmpty()) {
            return target;
        }

        Node<T> parent = target.hasParent() ? target.getParent() : null;
        Node<T> successor = null;
        Node<T> successorParent = null;
        if (target.hasRightChild()) {
            successor = target.getRightChild();
            while (successor.hasLeftChild()) {
                successor = successor.getLeftChild();
            }
            if (successor.getParent() != target) {
                successorParent = successor.getParent();
            }
        }

        Node<T> result = super.removeElement(n);

        // Candidates are listed deepest first. A candidate that the removal detached
        // fails isInTree(), so the next one up is used instead.
        Node<T> start = null;
        if (this.isInTree(successorParent)) {
            start = successorParent;
        } else if (this.isInTree(successor)) {
            start = successor;
        } else if (this.isInTree(target)) {
            start = target;
        } else if (this.isInTree(parent)) {
            start = parent;
        }
        // start stays null only when no candidate remains linked, e.g. the tree's last
        // node was removed. In that case nothing is left to rebalance.
        if (start != null) {
            this.balance(start);
        }
        return result;
    }

    /**
     * Restores the AVL invariant from {@code current} up to and including the root.
     *
     * <p>At each node on the way up, the stored heights are first recomputed from
     * the children. If the node's balance is outside [-1, 1], the appropriate single
     * or double rotation is applied. {@code current} must be the lowest node whose
     * subtree changed; every node below it must already hold correct heights.
     *
     * @param current lowest modified node; {@code null} or empty nodes are ignored
     */
    protected void balance(Node<T> current) {
        if (current == null || current.isEmpty()) {
            return;
        }
        Node<T> node = current;
        while (true) {
            this.updateHeights(node);
            int skew = node.getBalance();
            if (skew > 1) {
                // Right-heavy. Only a left-leaning right child needs the right-left
                // double rotation; a balanced (0) one, which can occur after a
                // deletion, must get a single left rotation.
                if (node.getRightChild().getBalance() < 0) {
                    this.rotateRight(node.getRightChild());
                }
                this.rotateLeft(node);
            } else if (skew < -1) {
                // Left-heavy; mirror image of the case above.
                if (node.getLeftChild().getBalance() > 0) {
                    this.rotateLeft(node.getLeftChild());
                }
                this.rotateRight(node);
            }
            // After a rotation, 'node' sits one level lower and its parent is the new
            // subtree root. Stepping to the parent therefore visits that root and then
            // every remaining ancestor. The root is processed before the loop stops.
            if (node == this.getRoot() || !node.hasParent()) {
                break;
            }
            node = node.getParent();
        }
    }

    /**
     * Rotates {@code n} left via the superclass and then refreshes the stored heights
     * of {@code n} and of its new parent (the former right child).
     *
     * @param n the node to rotate down-left
     */
    @Override
    protected void rotateLeft(Node<T> n) {
        super.rotateLeft(n);
        this.updateHeights(n);
        if (n.hasParent()) {
            this.updateHeights(n.getParent());
        }
    }

    /**
     * Rotates {@code n} right via the superclass and then refreshes the stored heights
     * of {@code n} and of its new parent (the former left child).
     *
     * @param n the node to rotate down-right
     */
    @Override
    protected void rotateRight(Node<T> n) {
        super.rotateRight(n);
        this.updateHeights(n);
        if (n.hasParent()) {
            this.updateHeights(n.getParent());
        }
    }

    /** Recomputes {@code node}'s stored left/right heights from its children's current heights. */
    private void updateHeights(Node<T> node) {
        node.setLeftHeight(node.hasLeftChild() ? node.getLeftChild().getHeight() + 1 : 0);
        node.setRightHeight(node.hasRightChild() ? node.getRightChild().getHeight() + 1 : 0);
    }

    /**
     * Returns whether {@code node} is still reachable from the root. This is checked
     * by verifying each parent's child link, so stale parent pointers on a detached
     * node are not trusted.
     */
    private boolean isInTree(Node<T> node) {
        if (node == null || node.isEmpty()) {
            return false;
        }
        Node<T> cursor = node;
        while (cursor != this.getRoot()) {
            if (!cursor.hasParent()) {
                return false;
            }
            Node<T> up = cursor.getParent();
            boolean linked = (up.hasLeftChild() && up.getLeftChild() == cursor)
                    || (up.hasRightChild() && up.getRightChild() == cursor);
            if (!linked) {
                return false;
            }
            cursor = up;
        }
        return true;
    }
}